Attention Mechanism

Tao Zou

2024-03-04

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

Input Data

Source                          Target
I am from China                 我来自中国
You and me are best friends     你我是最好的朋友

batch_size=2; the source batch is padded to 6 time steps and the target batch to num_steps=8. Special token indices: “<unk>”=0, “<pad>”=1, “<bos>”=2, “<eos>”=3.

‘I am from China’ -> \([4, 5, 6, 7, 1, 1]\)

‘You and me are best friends’ -> \([8, 9, 10, 11, 12, 13]\)

\[X:\begin{bmatrix}4&5&6&7&1&1\\8&9&10&11&12&13\end{bmatrix}\ \ X\_valid\_len:\begin{bmatrix}4&6\end{bmatrix}\]

‘我来自中国’ -> \([4, 5, 6, 7, 8, 1, 1, 1]\)

‘你我是最好的朋友’ -> \([9, 4, 11, 12, 13, 14, 15, 16]\)

\[Y:\begin{bmatrix}4&5&6&7&8&1&1&1\\9&4&11&12&13&14&15&16\end{bmatrix}\ \ Y\_valid\_len:\begin{bmatrix}5&8\end{bmatrix}\]
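For later reference, the two toy batches can be written out directly as tensors (a minimal sketch that hard-codes the index mappings above):

X = torch.tensor([[4, 5, 6, 7, 1, 1],
                  [8, 9, 10, 11, 12, 13]])          # source batch, padded to 6 steps
X_valid_len = torch.tensor([4, 6])
Y = torch.tensor([[4, 5, 6, 7, 8, 1, 1, 1],
                  [9, 4, 11, 12, 13, 14, 15, 16]])  # target batch, padded to 8 steps
Y_valid_len = torch.tensor([5, 8])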

Basic Functions

sequence_mask

def sequence_mask(X, valid_len, value=0.0):
    '''
    Set the entries of X beyond each sequence's valid length to `value`.

    Typical call:
    :param X: (batch_size, seq_len, input_dim) or (batch_size, seq_len)
    :param valid_len: (batch_size, )    # the 2-D (batch_size, seq_len) valid_lens case is handled in masked_softmax(), discussed below

    Call from masked_softmax(), where the scores come from
    (query_lens, num_hiddens) * (key_lens, num_hiddens)^T = (query_lens, key_lens):
    :param X: (batch_size * query_lens, key_lens)
    :param valid_len: (batch_size, ) --torch.repeat_interleave()--> valid_lens: (batch_size*query_lens, )
    '''
    maxlen = X.shape[1]
    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
X = torch.ones(2, 6, 8)  # shape: (batch_size, seq_len, input_dim)
valid_len = torch.tensor([4, 6]).reshape(2, )  # shape: (batch_size, )
print(sequence_mask(X, valid_len, -99))
## tensor([[[  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [-99., -99., -99., -99., -99., -99., -99., -99.],
##          [-99., -99., -99., -99., -99., -99., -99., -99.]],
## 
##         [[  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.],
##          [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]]])

(1, seq_len) < (batch_size, 1) \[ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1\end{bmatrix}\\\begin{bmatrix}2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ \begin{bmatrix}\begin{bmatrix}0&1&2\end{bmatrix}\\\begin{bmatrix}0&1&2\end{bmatrix}\end{bmatrix} < \begin{bmatrix}\begin{bmatrix}1&1&1\end{bmatrix}\\\begin{bmatrix}2&2&2\end{bmatrix}\end{bmatrix}\\ \downarrow\\ mask:\begin{bmatrix}\begin{bmatrix}True&False&False\end{bmatrix}\\\begin{bmatrix}True&True&False\end{bmatrix}\end{bmatrix} \]
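The broadcast above can be reproduced directly (a small sketch with seq_len=3 and batch_size=2):

print(torch.arange(3)[None, :] < torch.tensor([1, 2])[:, None])
# tensor([[ True, False, False],
#         [ True,  True, False]])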

masked_softmax

def masked_softmax(X, valid_lens):
    '''
    query: (2, 6, 14) * key: (2, 8, 14)^T = score: (2, 6, 8)
    score: (2, 6, 8) * value: (2, 8, 14) = (2, 6, 14)
    masked_softmax() is used to mask score.

    :param X: (batch_size, query_lens, key_lens)
    :param valid_lens: (batch_size, ) or (batch_size, query_lens)
    '''
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:  # 2-D valid_lens: one valid length per query row (used by the decoder's masked self-attention, discussed later)
            valid_lens = valid_lens.reshape(-1)
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return F.softmax(X.reshape(shape), dim=-1)
masked_softmax(torch.rand(2, 6, 8), torch.tensor([4, 6]))
## tensor([[[0.2318, 0.2245, 0.2486, 0.2951, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.3212, 0.2489, 0.2753, 0.1546, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2987, 0.1641, 0.1543, 0.3829, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2037, 0.1553, 0.2646, 0.3764, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2201, 0.2857, 0.2216, 0.2725, 0.0000, 0.0000, 0.0000, 0.0000],
##          [0.2362, 0.2553, 0.3232, 0.1853, 0.0000, 0.0000, 0.0000, 0.0000]],
## 
##         [[0.2086, 0.2447, 0.1044, 0.1487, 0.1750, 0.1186, 0.0000, 0.0000],
##          [0.1459, 0.1679, 0.2042, 0.1070, 0.1879, 0.1871, 0.0000, 0.0000],
##          [0.1071, 0.1046, 0.1502, 0.2344, 0.2091, 0.1947, 0.0000, 0.0000],
##          [0.1334, 0.0907, 0.0978, 0.2194, 0.2305, 0.2282, 0.0000, 0.0000],
##          [0.1899, 0.1082, 0.2001, 0.1261, 0.2665, 0.1093, 0.0000, 0.0000],
##          [0.2057, 0.1857, 0.1246, 0.1821, 0.1901, 0.1119, 0.0000, 0.0000]]])
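The else branch above handles a 2-D valid_lens with one valid length per query row, which is exactly what the decoder's masked self-attention passes in later (dec_valid_lens). A small sketch:

scores = torch.rand(2, 6, 6)
valid_lens = torch.arange(1, 7).repeat(2, 1)  # row i may attend to the first i keys
weights = masked_softmax(scores, valid_lens)
print(weights[0])  # lower-triangular pattern: row i has exactly i nonzero entries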

DotProductAttention

class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        '''
        :param queries: (batch_size, query_lens, num_hiddens)
        :param keys: (batch_size, key_lens, num_hiddens)
        :param values: (batch_size, value_lens, num_hiddens)
        '''
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)  # O(n^2 * d) when used as self-attention (n = seq_len, d = num_hiddens)
        self.attention_weights = masked_softmax(scores, valid_lens)  # O(n^2)
        return torch.bmm(self.dropout(self.attention_weights), values)  # O(n^2 * d)
queries, keys, values = torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14)), torch.normal(0, 1, (2, 6, 14))
attention = DotProductAttention(dropout=0.5)
attention.eval()
## DotProductAttention(
##   (dropout): Dropout(p=0.5, inplace=False)
## )
print(attention(queries, keys, values, torch.tensor([3, 4])).shape)
## torch.Size([2, 6, 14])

MultiHeadAttention

Suppose an input of shape (1, seq_lens=6, input_size=14) and num_heads=2. Head 1 processes the red block and head 2 processes the blue block. The transpose_qkv function reshapes \((1, 6, 14)\) into \((1*2, 6, 7)\) so that both heads can be computed in parallel with a single batched matrix multiplication; transpose_output reverses the operation and restores the original layout.

\[\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix}\mathop{\longrightarrow}^{transpose\_qkv()}\begin{matrix}\begin{bmatrix}\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1&\color{red}0\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}1\\\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\\\color{red}0&\color{red}1&\color{red}0&\color{red}0&\color{red}0&\color{red}0&\color{red}0\end{bmatrix}\\ \begin{bmatrix}\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}1&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\\\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0&\color{blue}0\end{bmatrix} \end{matrix}\]

def transpose_qkv(X, num_heads):
    # (batch_size, seq_len, num_hiddens) -> (batch_size, seq_len, num_heads, num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    # -> (batch_size, num_heads, seq_len, num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)
    # -> (batch_size*num_heads, seq_len, num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    # reverse the operations of transpose_qkv
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
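As a quick shape check (a minimal sketch using the toy dimensions above; the keyword names follow the constructor):

attention = MultiHeadAttention(key_size=14, query_size=14, value_size=14,
                               num_hiddens=14, num_heads=2, dropout=0.5)
attention.eval()
X = torch.normal(0, 1, (2, 6, 14))
print(transpose_qkv(X, 2).shape)                       # torch.Size([4, 6, 7])
print(attention(X, X, X, torch.tensor([4, 6])).shape)  # torch.Size([2, 6, 14])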

Positional encoding

\[\boldsymbol{Embed\_X}+\boldsymbol{P}=\begin{bmatrix}x_{11}&x_{12}&\cdots&x_{1d}\\x_{21}&x_{22}&\cdots&x_{2d}\\\vdots&\vdots&\ddots&\vdots\\x_{n1}&x_{n2}&\cdots&x_{nd}\end{bmatrix}+\begin{bmatrix}p_{11}&p_{12}&\cdots&p_{1d}\\p_{21}&p_{22}&\cdots&p_{2d}\\\vdots&\vdots&\ddots&\vdots\\p_{n1}&p_{n2}&\cdots&p_{nd}\end{bmatrix}\], where \(n\) represents seq_lens, \(d\) represents the embedding size, and \(\boldsymbol{P}\) is the positional encoding matrix.

\[p_{i, 2j}=\sin\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\ \ \ \ \ \ \ p_{i, 2j+1}=\cos\Big{(}\frac{i}{10000^{2j/d}}\Big{)}\], where \(i=0,1,\cdots,n-1\) and \(j=0, 1, \cdots,d/2-1\).

Properties:

\[\begin{bmatrix}p_{i+k, 2j}\\p_{i+k, 2j+1}\end{bmatrix}=\begin{bmatrix}\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\\-\sin\Big{(}\frac{k}{10000^{2j/d}}\Big{)}&\cos\Big{(}\frac{k}{10000^{2j/d}}\Big{)}\end{bmatrix}\begin{bmatrix}p_{i,2j}\\p_{i,2j+1}\end{bmatrix}\]
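A quick numeric check of this rotation property (a sketch; the values of d, i, k and j are arbitrary):

import numpy as np

d, i, k, j = 512, 10, 25, 3
w = 1.0 / 10000 ** (2 * j / d)                     # frequency of the (2j, 2j+1) column pair
p = lambda pos: np.array([np.sin(pos * w), np.cos(pos * w)])
R = np.array([[np.cos(k * w),  np.sin(k * w)],
              [-np.sin(k * w), np.cos(k * w)]])
print(np.allclose(p(i + k), R @ p(i)))             # True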

\[\begin{bmatrix}p_{i,0}&p_{i,1}&\cdots&p_{i,d-1}\end{bmatrix}\begin{bmatrix}p_{i+k,0}\\p_{i+k, 1}\\\vdots\\p_{i+k, d-1}\end{bmatrix}=\cos\Big{(}\frac{k}{10000^{0/d}}\Big{)}+\cos\Big{(}\frac{k}{10000^{2/d}}\Big{)}+\cdots+\cos\Big{(}\frac{k}{10000^{(d-2)/d}}\Big{)}\]

Assuming \(d=512\), the inner product of vectors as \(k\) increases is shown below.

import plotly.express as px
import pandas as pd
import numpy as np

def myfunc(k, d):
    exponents = np.arange(0, d, 2)/d
    a = k / np.power(10000, exponents)
    a = np.cos(a)
    # print(a)
    return np.sum(a)
k = np.arange(0, 200)
y = [myfunc(item, 512) for item in k]
df = pd.DataFrame({'k': k, 'y': y})
fig = px.scatter(df, x='k', y='y', width=768, height=474)
fig.show()
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / 
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
    
    def forward(self, X):
        X += self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
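A quick shape check of the positional encoding (a minimal sketch; dropout is set to 0 so the output is deterministic):

pos_encoding = PositionalEncoding(num_hiddens=14, dropout=0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((2, 6, 14)))
print(X.shape)     # torch.Size([2, 6, 14])
print(X[0, :, 0])  # first sine column: sin(0), sin(1), ..., sin(5)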

Add&Norm

Suppose I have an input of dimension \((batch\_size=2, seq\_lens=2, input\_size=4)\).

\[ \begin{bmatrix}x^{(1)}_{11}&x^{(1)}_{12}&x^{(1)}_{13}&x^{(1)}_{14}\\x^{(1)}_{21}&x^{(1)}_{22}&x^{(1)}_{23}&x^{(1)}_{24}\end{bmatrix}\\ \begin{bmatrix}x^{(2)}_{11}&x^{(2)}_{12}&x^{(2)}_{13}&x^{(2)}_{14}\\x^{(2)}_{21}&x^{(2)}_{22}&x^{(2)}_{23}&x^{(2)}_{24}\end{bmatrix} \]

The normalization operator nn.LayerNorm(4) is applied to every token (it normalizes over the last dimension).

\[ mean^{(1)}_1=\frac{1}{4}\sum_j^4x^{(1)}_{1j}\\Var^{(1)}_1=\frac{1}{4}\sum_j\big{(}x^{(1)}_{1j}-mean_1^{(1)}\big{)}^2 \]

The normalization operator nn.LayerNorm([2, 4]) is applied to each input sequence as a whole (it normalizes over the last two dimensions).

\[ mean^{(1)}=\frac{1}{2\times4}\sum^2_i\sum^4_jx^{(1)}_{ij}\\Var^{(1)}=\frac{1}{2\times4}\sum_i^2\sum_j^4\big{(}x^{(1)}_{ij}-mean^{(1)}\big{)}^2 \]

ln1 = nn.LayerNorm(4)
ln2 = nn.LayerNorm([2, 4])
with torch.no_grad():
    # X shape: (2, 2, 4)
    X = torch.tensor([
        [[1, 2, 3, 4],
         [5, 6, 7, 8]],
        [[5, 6, 7, 8],
         [5, 1, 0, -1]]
    ], dtype=torch.float32)
    print(ln1(X))
    print(ln2(X))
## tensor([[[-1.3416, -0.4472,  0.4472,  1.3416],
##          [-1.3416, -0.4472,  0.4472,  1.3416]],
## 
##         [[-1.3416, -0.4472,  0.4472,  1.3416],
##          [ 1.6465, -0.1098, -0.5488, -0.9879]]])
## tensor([[[-1.5275, -1.0911, -0.6547, -0.2182],
##          [ 0.2182,  0.6547,  1.0911,  1.5275]],
## 
##         [[ 0.3538,  0.6683,  0.9829,  1.2974],
##          [ 0.3538, -0.9042, -1.2187, -1.5332]]])
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)
    
    def forward(self, X, Y):
        return self.ln(X + self.dropout(Y))
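A shape check (sketch): the residual connection plus layer normalization keeps the input shape unchanged.

add_norm = AddNorm([6, 14], dropout=0.5)
add_norm.eval()
print(add_norm(torch.ones((2, 6, 14)), torch.ones((2, 6, 14))).shape)  # torch.Size([2, 6, 14])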

PositionWiseFFN

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
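The FFN is applied to every position independently, so only the last (feature) dimension is transformed. A quick sketch:

ffn = PositionWiseFFN(ffn_num_input=14, ffn_num_hiddens=28, ffn_num_outputs=14)
print(ffn(torch.ones((2, 6, 14))).shape)  # torch.Size([2, 6, 14])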

Encoder&Decoder

Encoder

class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))  # rescale the embeddings by sqrt(d) so their magnitude matches the positional encodings
        self.attention_weights = [None] * len(self.blks)  # the (post-softmax) attention weight matrix of every block
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
encoder = TransformerEncoder(200, 14, 14, 14, 14, [6, 14], 14, 28, 2, 6, 0.5)
#encoder.eval()
X = torch.ones((2, 6), dtype=torch.long)
valid_lens = torch.tensor([4, 6], dtype=torch.long)
print(encoder(X, valid_lens).shape)
## torch.Size([2, 6, 14])

Decoder

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

class AttentionDecoder(nn.Module):
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    def attention_weights(self):
        raise NotImplementedError
      
      
class TransformerDecoder(AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, 
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)
    
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
    
    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state
    
    @property
    def attention_weights(self):
        return self._attention_weights

The first multi-head attention layer in the decoder

Query after being transposed by transpose_qkv: \((4, 8, 7)\).

Key after being transposed by transpose_qkv: \((4, 8, 7)\rightarrow^T(4, 7, 8)\).

\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\begin{bmatrix}1&0&0&0&0&0&0&0\\s_{21}&s_{22}&0&0&0&0&0&0\\s_{31}&s_{32}&s_{33}&0&0&0&0&0\\s_{41}&s_{42}&s_{43}&s_{44}&0&0&0&0\\s_{51}&s_{52}&s_{53}&s_{54}&s_{55}&0&0&0\\s_{61}&s_{62}&s_{63}&s_{64}&s_{65}&s_{66}&0&0\\s_{71}&s_{72}&s_{73}&s_{74}&s_{75}&s_{76}&s_{77}&0\\s_{81}&s_{82}&s_{83}&s_{84}&s_{85}&s_{86}&s_{87}&s_{88}\end{bmatrix}, \forall i=2,\cdots8; \sum_j s_{ij}=1\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times8}\rightarrow\cdots\cdots\\ \]

The score matrix above shows that the k-th token can only attend to the first k tokens (itself and the tokens before it).

The second multi-head attention layer in the decoder (encoder-decoder attention)

Query after being transposed by transpose_qkv: \((4, 8, 7)\).

Key (from the encoder outputs) after being transposed by transpose_qkv: \((4, 6, 7)\rightarrow^T(4, 7, 6)\).

\[ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\begin{bmatrix}\cdots&0&0\\\ddots&\vdots&\vdots\\\cdots&0&0\end{bmatrix}_{8\times6}\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots\\ \begin{bmatrix}\end{bmatrix}_{8\times7}\cdot\begin{bmatrix}\end{bmatrix}_{7\times6}\rightarrow\cdots \]

The score matrices above show that the padded (invalid) positions of the encoder outputs are ignored: the first sample has enc_valid_len = 4, so its last two key columns receive zero weight after masked softmax.
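These shapes can be reproduced with a single DecoderBlock (a minimal sketch; the random enc_outputs tensor only stands in for real encoder outputs):

decoder_blk = DecoderBlock(14, 14, 14, 14, [8, 14], 14, 28, 2, 0.5, 0)
X = torch.ones((2, 8, 14))                            # decoder input: 8 steps
enc_outputs = torch.normal(0, 1, (2, 6, 14))          # stand-in for encoder output: 6 steps
state = [enc_outputs, torch.tensor([4, 6]), [None]]
out, state = decoder_blk(X, state)                    # training mode -> causal dec_valid_lens
print(out.shape)                                                  # torch.Size([2, 8, 14])
print(decoder_blk.attention1.attention.attention_weights.shape)   # torch.Size([4, 8, 8])
print(decoder_blk.attention2.attention.attention_weights.shape)   # torch.Size([4, 8, 6])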

Model

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
    
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)
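The hyperparameters above (key_size, num_hiddens, norm_shape, ...) and the vocabularies src_vocab / tgt_vocab are assumed to come from the data pipeline. Purely illustrative values (my own choice, not taken from this post) that would be defined before the three lines above could be:

num_hiddens, num_layers, num_heads, dropout = 32, 2, 4, 0.1
key_size = query_size = value_size = num_hiddens
ffn_num_input, ffn_num_hiddens = num_hiddens, 64
norm_shape = [num_hiddens]  # LayerNorm over the feature dimension only, so it also
                            # works when the decoder processes one token at a time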

Train And Prediction

Masked softmax loss

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    def forward(self, pred, label, valid_len):
        '''
        :pred's shape: (batch_size, num_steps, vocab_size)
        :label's shape: (batch_size, num_steps)
        :valid_len's shape: (batch_size, )
        '''
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = 'none'
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
            pred.permute(0, 2, 1), label)
        # nn.CrossEntropyLoss expects the class dimension second, hence pred.permute(0, 2, 1) when pred and label carry batch and step dimensions.
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss  # Each sequence has a loss value.
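A quick sanity check (sketch): with uniform predictions every token costs ln(vocab_size), and positions beyond valid_len contribute nothing.

loss = MaskedSoftmaxCELoss()
l = loss(torch.ones(3, 4, 10), torch.ones((3, 4), dtype=torch.long), torch.tensor([4, 2, 0]))
print(l)  # roughly tensor([2.3026, 1.1513, 0.0000]): ln(10), ln(10)*2/4, 0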

Train function

def train_seq2seq(model, data_iter, lr, num_epochs, tgt_vocab, device):
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    model.apply(xavier_init_weights)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()

    model.train()
    for epoch in range(num_epochs):
        myloss = 0
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            # teacher forcing: the decoder input is <bos> followed by the ground-truth target shifted right
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)

            Y_hat, _ = model(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()
            # d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            myloss += (l.sum() / num_tokens).item()  # detach for logging
        if (epoch + 1) % 10 == 0:
            print("epoch {}, loss: {:.4f}".format(epoch + 1, myloss))

Prediction

def truncate_pad(line, num_steps, padding_token):
    if len(line) > num_steps:
        return line[:num_steps]
    else:
        return line + [padding_token] * (num_steps - len(line))
def predict_seq2seq(model, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    model.eval()
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = model.encoder(enc_X, enc_valid_len)
    
    dec_state = model.decoder.init_state(enc_outputs, enc_valid_len)
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = model.decoder(dec_X, dec_state)
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        
        if save_attention_weights:
            attention_weight_seq.append(model.decoder.attention_weights)
        
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq

Transformer Code

The following consolidates all of the Transformer architecture code above in one place, for convenient copy-pasting.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sequence_mask(X, valid_len, value=0.0):
    maxlen = X.shape[1]
    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
  
def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return F.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:  # 2-D valid_lens: one valid length per query row (used by the decoder's masked self-attention)
            valid_lens = valid_lens.reshape(-1)
        X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return F.softmax(X.reshape(shape), dim=-1)

class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):

        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = (torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / 
             torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens))
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
    
    def forward(self, X):
        X += self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)
    
    def forward(self, X, Y):
        return self.ln(X + self.dropout(Y))

class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))  # rescale the embeddings by sqrt(d) so their magnitude matches the positional encodings
        self.attention_weights = [None] * len(self.blks)  # the (post-softmax) attention weight matrix of every block
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X

class DecoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            dec_valid_lens = torch.arange(1, num_steps+1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

class AttentionDecoder(nn.Module):
    def __init__(self, **kwargs):
        super(AttentionDecoder, self).__init__(**kwargs)

    def attention_weights(self):
        raise NotImplementedError
      
      
class TransformerDecoder(AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, 
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block"+str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)
    
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]
    
    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state
    
    @property
    def attention_weights(self):
        return self._attention_weights

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
    
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
model = EncoderDecoder(encoder, decoder)